# Owner(s): ["module: nn"]
import math
import unittest
import itertools
import warnings
from itertools import product

import torch

import torch.autograd.forward_ad as fwAD
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_dtype import floating_types_and, floating_and_complex_types_and
from torch.testing._internal.common_utils import run_tests, \
    skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_SCIPY, TEST_WITH_ROCM, \
    download_file, parametrize as parametrize_test, subtest, \
    instantiate_parametrized_tests, set_default_dtype
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
from torch.testing._internal.common_nn import NNTestCase, _test_module_empty_input
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
    dtypesIfCUDA, precisionOverride, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
    skipCUDAIfRocm, skipCUDAIfRocmVersionLessThan, skipCUDAIfNotMiopenSuggestNHWC, \
    onlyNativeDeviceTypes, largeTensorTest, skipMeta, \
    disableMkldnn, skipCPUIfNoMkldnn, disablecuDNN, skipCUDAIfMiopen, skipCUDAIfNoMiopen

from torch.testing import make_tensor
from torch.testing._internal.common_utils import gradcheck, gradgradcheck, \
    GRADCHECK_NONDET_TOL
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32

# True when the suite runs under ROCm (TEST_WITH_ROCM) or when
# tf32_is_not_fp32() reports true. The grouped Conv2d tests in this file use
# it to decide whether to also cover bfloat16 on CUDA.
AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()


if TEST_SCIPY:
    import scipy.signal
    import scipy.ndimage


class TestConvolutionNN(NNTestCase):
    """Tests for the ``torch.nn`` convolution modules and their functional forms."""

    # Class-level switches read by the test base class (defined outside this
    # file). NOTE(review): presumably they enable the CUDA memory-leak check
    # and running on a non-default CUDA stream -- confirm against the base class.
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def test_conv_backcompat(self):
        """Load a Conv2d module pickled by an old PyTorch and run one forward pass."""
        from torch.serialization import SourceChangeWarning

        # The checkpoint was generated on PyTorch 1.0.1 under Python 2 with:
        #
        #     import torch
        #     from torch import nn
        #     m = nn.Conv2d(1, 1, 1)
        #     torch.save(m, 'legacy_conv2d.pt')
        #
        # NB: The pickle also contains some Unicode data!
        legacy_path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
        with warnings.catch_warnings():
            # Ignore the warning about the module source having changed since pickling.
            warnings.simplefilter('ignore', SourceChangeWarning)
            legacy_conv = torch.load(legacy_path, encoding='utf-8')
        sample = torch.randn((1, 1, 1, 1), dtype=torch.float)
        result = legacy_conv(sample)
        self.assertEqual(result.size(), (1, 1, 1, 1))

    def test_invalid_conv1d(self):
        for dtype in [torch.half, torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]:
            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError,
                                        r'Calculated padded input size per channel: \(4\). ' +
                                        r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'):
                module(input)

            # Negative stride check
            module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

    def test_mismatch_shape_conv2d(self):
        for dtype in (torch.float, torch.cfloat):
            x = torch.randn(1, 10, 1, 28, 28, dtype=dtype)
            w = torch.randn(6, 1, 5, 5, dtype=dtype)

            with self.assertRaisesRegex(RuntimeError,
                                        r'Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got ' +
                                        r'input of size: \[1, 10, 1, 28, 28\]'):

                F.conv2d(x, w)

    def test_conv2d_discontiguous_weight(self):
        for dtype in (torch.float, torch.cfloat):
            # Test for https://github.com/pytorch/pytorch/issues/55781
            x = torch.ones(64, 16, 16, 16, dtype=dtype)
            weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2).to(dtype)[:, :, :, ::2]
            self.assertFalse(weight.is_contiguous())
            y = torch.nn.functional.conv2d(x, weight, None)
            if torch.backends.mkldnn.is_available():
                # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
                with torch.backends.mkldnn.flags(enabled=False):
                    y_ = torch.nn.functional.conv2d(x, weight, None)
                    self.assertEqual(y, y_)
            self.assertEqual(y.sum(), 4186112.)

    def test_invalid_conv2d(self):
        for dtype in [torch.half, torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]:
            module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
            input = torch.empty(1, 1, 4, 4).to(dtype)
            self.assertRaises(RuntimeError, lambda: module(input))

            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
            input = torch.randn(1, 3, 1, 1)
            with self.assertRaisesRegex(RuntimeError,
                                        r'Calculated padded input size per channel: \(1 x 1\). ' +
                                        r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'):
                module(input)

            # Negative stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

            # Zero stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

    def test_invalid_conv3d(self):
        for dtype in [torch.half, torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]:
            module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
            input = torch.empty(1, 1, 4, 4, 4).to(dtype)
            self.assertRaises(RuntimeError, lambda: module(input))

            # Negative stride check
            module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2)
            input = torch.empty(1, 1, 4, 4, 4)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

    def test_conv_invalid_groups(self):
        with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
            torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0)
        with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
            torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1)
        with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
            torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2)

    def test_Conv1d_module_same_padding(self):
        # Compare module against functional: without strides/dilation, asymmetric padding
        x = torch.rand(1, 1, 20)
        module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10,
                           padding='same')
        expect = F.conv1d(x, module.weight, module.bias, padding='same')
        self.assertEqual(expect, module(x))

        # Test dilation, symmetric padding
        module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10,
                           padding='same', dilation=2)
        expect = F.conv1d(x, module.weight, module.bias, padding='same', dilation=2)
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10,
                           padding='same', padding_mode='replicate')
        x_padded = F.pad(x, [4, 5], mode='replicate')
        expect = F.conv1d(x_padded, module.weight, module.bias, padding='valid')
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test connstruction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, 'Invalid padding string'):
            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='foo')

        # Test connstruction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2)

    def test_Conv2d_module_same_padding(self):
        # Compare module against functional:
        # without strides/dilation, both symmetric and asymmetric padding
        x = torch.rand(1, 1, 9, 20)
        module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(5, 10),
                           padding='same')
        expect = F.conv2d(x, module.weight, module.bias, padding='same')
        self.assertEqual(expect, module(x))

        # with dilation, symmetric padding
        module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4),
                           padding='same', dilation=(1, 2))
        expect = F.conv2d(x, module.weight, module.bias, padding='same', dilation=(1, 2))
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4),
                           padding='same', padding_mode='reflect')
        x_padded = F.pad(x, [1, 2, 1, 1], mode='reflect')
        expect = F.conv2d(x_padded, module.weight, module.bias, padding='valid')
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test connstruction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, 'Invalid padding string'):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='foo')

        # Test connstruction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2)
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 3))
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(4, 1))

    def test_Conv3d_module_same_padding(self):
        # Compare module against functional:
        x = torch.rand(1, 1, 4, 4, 4)
        # without dilation, both symmetric and asymmetric padding
        module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4),
                           padding='same')
        expect = F.conv3d(x, module.weight, module.bias, padding='same')
        self.assertEqual(expect, module(x))

        # with dilation, both symmetric and asymmetric padding
        module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4),
                           padding='same', dilation=(3, 2, 1))
        expect = F.conv3d(x, module.weight, module.bias, padding='same', dilation=(3, 2, 1))
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4),
                           padding='same', padding_mode='circular')
        x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode='circular')
        expect = F.conv3d(x_padded, module.weight, module.bias, padding='valid')
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test connstruction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, 'Invalid padding string'):
            module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='foo')

        # Test connstruction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2)
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 1, 3))
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 4, 1))
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(5, 1, 1))

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_thnn_conv_strided_padded_dilated(self):
        """With cuDNN disabled, CUDA conv/conv_transpose must match CPU and pass gradcheck."""
        conv_variants = (
            (torch.nn.functional.conv2d, 2, False),
            (torch.nn.functional.conv_transpose2d, 2, True),
            (torch.nn.functional.conv3d, 3, False),
            (torch.nn.functional.conv_transpose3d, 3, True),
        )
        # (stride, padding, dilation) combinations to exercise.
        geometries = ((2, 0, 1), (1, 1, 1), (2, 1, 1), (1, 0, 2))

        for (convfn, dims, _transposed), (stride, padding, dilation) in product(conv_variants, geometries):
            kwargs = {"stride": stride, "padding": padding, "dilation": dilation}

            def run_conv(x, w, b):
                return convfn(x, w, b, **kwargs)

            inputs = torch.randn((1, 2) + dims * (4,), dtype=torch.double, device="cuda", requires_grad=True)
            weight = torch.randn((2, 2) + dims * (1,), dtype=torch.double, device="cuda", requires_grad=True)
            bias = torch.randn(2, dtype=torch.double, device="cuda", requires_grad=True)

            with torch.backends.cudnn.flags(enabled=False):
                cuda_result = convfn(inputs, weight, bias, **kwargs)
            cpu_result = convfn(inputs.cpu(), weight.cpu(), bias.cpu(), **kwargs)
            self.assertEqual(cuda_result, cpu_result)

            with torch.backends.cudnn.flags(enabled=False):
                torch.autograd.gradcheck(run_conv, (inputs, weight, bias))
                torch.autograd.gradcheck(run_conv, (inputs.cpu(), weight.cpu(), bias.cpu()))

    def test_Conv2d_inconsistent_types(self):
        inputs = torch.randn(4, 1, 7, 7, dtype=torch.float)
        weights = torch.randn(1, 1, 3, 3, dtype=torch.double)
        # inconsistent types should raise an exception
        self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
        # but it should work with the same type
        nn.functional.conv2d(inputs.float(), weights.float())

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self):
        """With cuDNN disabled, CUDA conv2d must reject mixed float/double operands."""
        conv2d = nn.functional.conv2d
        float_input = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
        double_weight = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
        double_bias = torch.randn(1, dtype=torch.double, device="cuda")

        with torch.backends.cudnn.flags(enabled=False):
            # inconsistent types should raise an exception
            with self.assertRaises(RuntimeError):
                conv2d(float_input, double_weight)
            with self.assertRaises(RuntimeError):
                conv2d(float_input, double_weight.float(), double_bias)

            # but it should work with the same type
            conv2d(float_input.float(), double_weight.float(), double_bias.float())

    def test_Conv2d_1x1(self):
        in_channels = 2
        out_channels = 2
        mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double)
        input = torch.randn(1, in_channels, 5, 5, requires_grad=True, dtype=torch.double)
        for enabled in (False, True):
            with torch.backends.mkldnn.flags(enabled=enabled):
                gradcheck(F.conv2d, (input, mod.weight))

    def test_Conv2d_OneDNN(self):
        def run_once(group_val=24, dilation=1):
            ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32)
            weights = torch.ones([group_val, 1, 3, 3], dtype=torch.float32)
            op = torch.nn.Conv2d(
                in_channels=group_val,
                out_channels=group_val,
                kernel_size=[3, 3],
                stride=[2, 2],
                padding=[1, 1],
                dilation=[dilation, dilation],
                groups=group_val,
                bias=False,
                padding_mode='zeros'
            )

            op.weight.data = weights
            res = op(ifm)
            grad_in = torch.ones(res.shape, dtype=torch.float32)
            res.backward(grad_in)
            return op.weight.grad

        for gorup_val in (24, 48, 23, 25):
            for dilation in (1, 2):
                with torch.backends.mkldnn.flags(enabled=False):
                    without_onednn = run_once(gorup_val, dilation)

                with torch.backends.mkldnn.flags(enabled=True):
                    with_onednn = run_once(gorup_val, dilation)

                self.assertEqual(without_onednn, with_onednn)

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    def test_cudnn_non_contiguous(self):
        """Smoke test: a CUDA Conv1d forward pass on a permuted input must not raise."""
        base = torch.randn(192, 16, 50).cuda()
        # Same logical shape as `base`, but backed by a transposed memory layout.
        permuted = base.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        conv = torch.nn.Conv1d(
            in_channels=16,
            out_channels=32,
            kernel_size=2,
            bias=True).cuda()
        conv(permuted)

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
        """With cuDNN enabled, CUDA conv2d must reject mixed float/double operands."""
        conv2d = nn.functional.conv2d
        float_input = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
        double_weight = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
        double_bias = torch.randn(1, dtype=torch.double, device="cuda")

        with torch.backends.cudnn.flags(enabled=True):
            # inconsistent types should raise an exception
            with self.assertRaises(RuntimeError):
                conv2d(float_input, double_weight)
            with self.assertRaises(RuntimeError):
                conv2d(float_input, double_weight.float(), double_bias)

            # but it should work with the same type
            conv2d(float_input.float(), double_weight.float(), double_bias.float())

    def test_Conv2d_missing_argument(self):
        c = nn.Conv2d(3, 3, 3)
        self.assertRaises(TypeError, lambda: c(None))

    def test_Conv2d_backward_twice(self):
        input = torch.randn(2, 3, 5, 5)
        c = nn.Conv2d(3, 3, 3)
        o1 = c(input)
        o1.sum().backward()
        self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
                               lambda: o1.sum().backward())

    def test_conv_modules_raise_error_on_incorrect_input_size(self):
        for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]:
            modules = [nn.Conv1d(3, 8, 3).to(dtype), nn.ConvTranspose1d(3, 8, 3).to(dtype),
                       nn.Conv2d(3, 8, 3).to(dtype), nn.ConvTranspose2d(3, 8, 3).to(dtype),
                       nn.Conv3d(3, 8, 3).to(dtype), nn.ConvTranspose3d(3, 8, 3).to(dtype)]

            invalid_input_dims = [(1, 4), (1, 4),
                                  (2, 5), (2, 5),
                                  (3, 6), (3, 6)]

            for invalid_dims, module in zip(invalid_input_dims, modules):
                for dims in invalid_dims:
                    input = torch.empty(torch.Size((3, ) * dims))
                    self.assertRaises(RuntimeError, lambda: module(input))

    def test_conv_shapecheck(self):
        def test(should_raise, module, input_size, dtype):
            input = torch.empty(3, *input_size).to(dtype)
            if should_raise:
                self.assertRaises(RuntimeError, lambda: module(input))
            else:
                # just run it to ensure no exception raised.
                module(input)

        for dtype in [torch.half, torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]:
            # Conv1d
            test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype)
            test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype)
            test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype)
            test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype)
            test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype)

            # Conv2d
            test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype)
            test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype)
            test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype)

            # Conv3D
            test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype)
            test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype)
            test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype), (1, 2, 2, 2), dtype)

    def test_ConvTranspose2d_output_size(self):
        m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
        i = torch.randn(2, 3, 6, 6)
        for h in range(15, 22):
            for w in range(15, 22):
                if 18 <= h <= 20 and 18 <= w <= 20:
                    output = m(i, output_size=(h, w))
                    self.assertEqual(output.size()[2:], (h, w))
                else:
                    self.assertRaises(ValueError, lambda: m(i, (h, w)))

    def test_ConvTranspose2d_output_size_downsample_upsample(self):
        b, c, hid_c = 2, 3, 2
        for h in range(13, 24):
            for w in range(13, 17):
                for k in range(2, 5):
                    for d in range(1, 5):
                        for s in range(1, 4):
                            for p in range(3):
                                conv = nn.Conv2d(
                                    in_channels=c,
                                    out_channels=hid_c,
                                    kernel_size=k,
                                    stride=s,
                                    padding=p,
                                    dilation=d,
                                )

                                t_conv = nn.ConvTranspose2d(
                                    in_channels=hid_c,
                                    out_channels=c,
                                    kernel_size=k,
                                    stride=s,
                                    padding=p,
                                    dilation=d,
                                )

                                i = torch.randn(b, c, h, w)

                                out = t_conv(conv(i), output_size=i.shape)

                                self.assertEqual(out.size()[2:], i.size()[2:])

    def test_ConvTranspose3d_correct_output_size(self):
        # Check that ConvTranspose3d can take a 5d output_size.
        m = nn.ConvTranspose3d(2, 2, 2)
        i = torch.rand(1, 2, 1, 1, 1)
        out = m(i, output_size=(1, 2, 2, 2, 2))

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_ConvTranspose2d_half_cublas_gemm(self):
        """Smoke test: half ConvTranspose2d forward and backward on CUDA with cuDNN disabled."""
        with torch.backends.cudnn.flags(enabled=False):
            sample = torch.randn(1, 1, 16, 16, device='cuda', dtype=torch.half)
            deconv = nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1, output_padding=1)
            deconv = deconv.cuda().half()
            loss = deconv(sample).mean()
            loss.backward()

    # For https://github.com/pytorch/pytorch/pull/1273
    # Almost identical to the above `test_Conv2d_naive_groups`
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
    def test_Conv2d_groups_nobias(self):
        """A bias-free Conv2d with groups=2 must match two separate convs on the channel halves."""
        dev_dtypes = [("cpu", torch.float)]
        if TEST_CUDA:
            dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
        if AMPERE_OR_ROCM:
            dev_dtypes += [("cuda", torch.bfloat16)]

        # Channel slices covered by each of the two groups (4 channels -> 2 + 2).
        halves = (slice(None, 2), slice(2, None))

        for device, dtype in dev_dtypes:
            grouped = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype)
            full_input = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
            full_output = grouped(full_input)
            grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
            full_output.backward(grad_output)

            part_outputs, part_input_grads, part_weight_grads = [], [], []
            for half in halves:
                # An ungrouped conv carrying this group's share of the weights.
                part = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype)
                part.weight.data.copy_(grouped.weight.data[half])
                part_input = full_input.data[:, half].contiguous().requires_grad_(True)
                part_output = part(part_input)
                part_output.backward(grad_output[:, half].contiguous())
                part_outputs.append(part_output)
                part_input_grads.append(part_input.grad.data)
                part_weight_grads.append(part.weight.grad.data)

            self.assertEqual(full_output, torch.cat(part_outputs, 1))
            self.assertEqual(full_input.grad.data,
                             torch.cat(part_input_grads, 1),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            weight_atol = 1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype]
            self.assertEqual(grouped.weight.grad.data,
                             torch.cat(part_weight_grads, 0),
                             atol=weight_atol, rtol=0)

    # Almost identical to the above `test_Conv2d_naive_groups`
    # Covering special case when group > 1, input-channel / group < 16 and output-channel is multiple of 16
    # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686
    # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
    def test_Conv2d_groups_nobias_v2(self):
        """Grouped bias-free Conv2d (4 -> 16 channels, groups=2) must match two separate convs."""
        torch.manual_seed(123)
        dev_dtypes = [("cpu", torch.float)]
        if TEST_CUDA:
            dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
        if AMPERE_OR_ROCM:
            dev_dtypes += [("cuda", torch.bfloat16)]

        # (input-channel slice, output-channel slice) for each of the two groups.
        halves = (
            (slice(None, 2), slice(None, 8)),
            (slice(2, None), slice(8, None)),
        )

        for device, dtype in dev_dtypes:
            grouped = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype)
            full_input = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
            full_output = grouped(full_input)
            grad_output = torch.randn(2, 16, 4, 4, device=device, dtype=dtype)
            full_output.backward(grad_output)

            part_outputs, part_input_grads, part_weight_grads = [], [], []
            for in_half, out_half in halves:
                # An ungrouped conv carrying this group's share of the weights.
                part = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype)
                part.weight.data.copy_(grouped.weight.data[out_half])
                part_input = full_input.data[:, in_half].contiguous().requires_grad_(True)
                part_output = part(part_input)
                part_output.backward(grad_output[:, out_half].contiguous())
                part_outputs.append(part_output)
                part_input_grads.append(part_input.grad.data)
                part_weight_grads.append(part.weight.grad.data)

            self.assertEqual(full_output, torch.cat(part_outputs, 1))
            self.assertEqual(full_input.grad.data,
                             torch.cat(part_input_grads, 1),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            weight_atol = 1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype]
            self.assertEqual(grouped.weight.grad.data,
                             torch.cat(part_weight_grads, 0),
                             atol=weight_atol, rtol=0)

    # CPU-only test for group conv3d fast implementation using bmm
    # See: https://github.com/pytorch/pytorch/pull/36355
    def test_Conv3d_groups_nobias(self):
        torch.manual_seed(123)
        m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to("cpu", torch.float)
        i = torch.randn(2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
        output.backward(grad_output)

        m1 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
        m1.weight.data.copy_(m.weight.data[:8])
        i1 = i.data[:, :2].contiguous().requires_grad_(True)
        output1 = m1(i1)
        output1.backward(grad_output[:, :8].contiguous())

        m2 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
        m2.weight.data.copy_(m.weight.data[8:])
        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
        output2 = m2(i2)
        output2.backward(grad_output[:, 8:].contiguous())

        self.assertEqual(output, torch.cat([output1, output2], 1))
        self.assertEqual(i.grad.data,
                         torch.cat([i1.grad.data, i2.grad.data], 1),
                         atol=dtype2prec_DONTUSE[torch.float], rtol=0)
        self.assertEqual(m.weight.grad.data,
                         torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                         atol=dtype2prec_DONTUSE[torch.float], rtol=dtype2prec_DONTUSE[torch.float])

    def test_Conv3d_groups_wbias(self):
        torch.manual_seed(123)
        m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to("cpu", torch.float)
        i = torch.randn(2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
        output.backward(grad_output)

        m1 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
        m1.weight.data.copy_(m.weight.data[:8])
        m1.bias.data.copy_(m.bias.data[:8])
        i1 = i.data[:, :2].contiguous().requires_grad_(True)
        output1 = m1(i1)
        output1.backward(grad_output[:, :8].contiguous())

        m2 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
        m2.weight.data.copy_(m.weight.data[8:])
        m2.bias.data.copy_(m.bias.data[8:])
        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
        output2 = m2(i2)
        output2.backward(grad_output[:, 8:].contiguous())

        self.assertEqual(output, torch.cat([output1, output2], 1))
        self.assertEqual(i.grad.data,
                         torch.cat([i1.grad.data, i2.grad.data], 1),
                         atol=dtype2prec_DONTUSE[torch.float],
                         rtol=dtype2prec_DONTUSE[torch.float])
        self.assertEqual(m.weight.grad.data,
                         torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                         atol=dtype2prec_DONTUSE[torch.float],
                         rtol=dtype2prec_DONTUSE[torch.float])
        self.assertEqual(m.bias.grad.data,
                         torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                         atol=dtype2prec_DONTUSE[torch.float], rtol=dtype2prec_DONTUSE[torch.float])

    def test_conv_tbc(self):
        with set_default_dtype(torch.double):
            inp = torch.randn(9, 4, 5, requires_grad=True)
            weight = torch.randn(3, 5, 6, requires_grad=True)
            bias = torch.randn(6, requires_grad=True)

            gradcheck(lambda i, w, b, pad: F.conv_tbc(i, w, b, pad), (inp, weight, bias, 3))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
    @skipIfRocmVersionLessThan((4, 3))
    @skipIfNotMiopenSuggestNHWC
    def test_grouped_conv_cudnn_nhwc_support(self):
        # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version
        input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last)
        weight = torch.randn((8, 4, 3, 3), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last)
        out = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), False, (0, 0), 4)
        input = torch.randn((16, 8, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last)
        out_transpose = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), True, (0, 0), 4)

    @unittest.expectedFailure
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
    def test_conv_cudnn_memory_layout_dominance(self):
        # desired behavior here is to have the memory_layout of conv.weight to
        # dominante the layout of output.
        # which is not the same as current behavior, we'll fix this in
        # following up PRs and remove the `expectedFailure` tag
        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
        conv = nn.Conv2d(8, 4, 3).cuda().float()

        out = conv(input)
        self.assertTrue(out.is_contiguous())

        input = input.contiguous(memory_format=torch.channels_last)
        out = conv(input)
        self.assertTrue(out.is_contiguous())

        conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last)
        out = conv(input)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

        input = input.contiguous()
        out = conv(input)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_cudnn_noncontiguous_weight(self):
        # Noncontiguous weights must be contiguous() before being
        # passed to cuDNN
        input = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3)
        weights1 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2)
        weights2 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2).contiguous()
        self.assertEqual(F.conv1d(input, weights1, bias=None, stride=2, dilation=2),
                         F.conv1d(input, weights2, bias=None, stride=2, dilation=2))

    def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient='input'):
        """Compare a functional conv-gradient helper against autograd.

        For a grid of kernel sizes, input sizes, batch sizes, strides and
        paddings, with and without a bias term, the forward convolution is
        run, its gradient is computed with ``torch.autograd.grad`` and the
        result is compared with what ``func_backward`` returns.

        Args:
            func_forward: forward convolution, called as
                ``func_forward(input, weight, stride=, padding=, dilation=, bias=)``.
            func_backward: gradient helper, called as
                ``func_backward(input_or_shape, weight_or_shape, grad_output,
                stride=, padding=, dilation=)``.
            dim: number of spatial dimensions appended to the tensor shapes.
            gradient: ``'input'`` to check the gradient w.r.t. the input
                (``func_backward`` then receives the input *shape*), or
                ``'weight'`` to check the gradient w.r.t. the weight
                (``func_backward`` then receives the weight *shape*).
        """
        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
            for batch, stride, padding, chan_in, chan_out, dilation in \
                    product([1, 2], [1, 2], [0, 1, 2], [2], [3], [1]):

                for has_bias in [True, False]:
                    input_shape = [batch, chan_in] + [inp_size] * dim
                    weight_shape = [chan_out, chan_in] + [kern] * dim

                    input = torch.randn(input_shape, requires_grad=True)
                    weight = torch.randn(weight_shape, requires_grad=True)
                    # Pass None explicitly when no bias is requested; otherwise
                    # the bias tensor of the previous iteration would leak in
                    # and the bias-free path would never be exercised.
                    bias = torch.randn([chan_out], requires_grad=True) if has_bias else None
                    output = func_forward(input, weight, stride=stride, padding=padding, dilation=dilation, bias=bias)

                    gradient_o = torch.randn(output.shape)
                    gradient_w = torch.autograd.grad(output, input if (gradient == 'input') else weight, gradient_o)

                    self.assertEqual(gradient_w[0],
                                     func_backward(
                                     input_shape if (gradient == 'input') else input,
                                     weight_shape if (gradient == 'weight') else weight,
                                     gradient_o,
                                     stride=stride,
                                     padding=padding,
                                     dilation=dilation))

    def test_grad_conv1d_input(self):
        self.run_grad_conv_test(F.conv1d, F.grad.conv1d_input, 1, 'input')

    def test_grad_conv1d_weight(self):
        self.run_grad_conv_test(F.conv1d, F.grad.conv1d_weight, 1, 'weight')

    def test_grad_conv2d_input(self):
        self.run_grad_conv_test(F.conv2d, F.grad.conv2d_input, 2, 'input')

    def test_grad_conv2d_weight(self):
        self.run_grad_conv_test(F.conv2d, F.grad.conv2d_weight, 2, 'weight')

    def test_grad_conv3d_input(self):
        self.run_grad_conv_test(F.conv3d, F.grad.conv3d_input, 3, 'input')

    def test_grad_conv3d_weight(self):
        self.run_grad_conv_test(F.conv3d, F.grad.conv3d_weight, 3, 'weight')

    @unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable")
    def test_nnpack_conv(self):
        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
            for batch, stride, padding, chan_in, chan_out in \
                    product([1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3]):

                for has_bias in [True, False]:
                    input_shape = [batch, chan_in]
                    weight_shape = [chan_out, chan_in]
                    for _ in range(2):
                        input_shape.append(inp_size)
                        weight_shape.append(kern)

                    input = torch.randn(input_shape, requires_grad=True, dtype=torch.float)
                    weight = torch.randn(weight_shape, requires_grad=True, dtype=torch.float)
                    if has_bias:
                        bias = torch.randn([chan_out], requires_grad=True, dtype=torch.float)
                    output = torch._nnpack_spatial_convolution(input, weight, stride=stride, padding=padding, bias=bias)
                    output_expected = torch.nn.functional.conv2d(
                        input, weight, stride=stride, padding=padding, bias=bias)
                    self.assertEqual(output, output_expected, atol=3e-4, rtol=0)

                    gradient_o = torch.randn(output.shape, dtype=torch.float)

                    grads = torch.autograd.grad(output, [input, weight], gradient_o)
                    grads_expected = torch.autograd.grad(output_expected, [input, weight], gradient_o)
                    for gr, gr_expected in zip(grads, grads_expected):
                        self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0)

    def test_conv_padding_mode(self):
        with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
            nn.Conv2d(3, 3, 3, padding_mode="xyz")

        with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
            nn.Conv2d(3, 3, 3, padding_mode=3)

        with self.assertRaisesRegex(ValueError, "Only \"zeros\" "):
            nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect")

    def test_functional_grad_conv(self):
        # Conv 1D
        input = torch.randn(1, 1, 5, requires_grad=True)
        weight = torch.randn(1, 1, 3, requires_grad=True)
        output = F.conv1d(input, weight, dilation=2)
        grad_output = torch.randn(output.shape)

        grad_input_autograd, grad_weight_autograd = torch.autograd.grad(output, (input, weight), grad_output)

        grad_input_functional = torch.nn.grad.conv1d_input(input.shape, weight, grad_output, dilation=2)
        self.assertEqual(grad_input_functional, grad_input_autograd)

        grad_weight_functional = torch.nn.grad.conv1d_weight(input, weight.shape, grad_output, dilation=2)
        self.assertEqual(grad_weight_functional, grad_weight_autograd)

        # Conv 2D
        input = torch.randn(1, 1, 5, 5, requires_grad=True)
        weight = torch.randn(1, 1, 3, 3, requires_grad=True)
        output = F.conv2d(input, weight, dilation=2)
        grad_output = torch.randn(output.shape)

        (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output)

        grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, dilation=2)
        self.assertEqual(grad_input_functional, grad_input_autograd)

        grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, dilation=2)
        self.assertEqual(grad_weight_functional, grad_weight_autograd)

        # Conv 3D
        input = torch.randn(1, 1, 5, 5, 5, requires_grad=True)
        weight = torch.randn(1, 1, 3, 3, 3, requires_grad=True)
        output = F.conv3d(input, weight, dilation=2)
        grad_output = torch.randn(output.shape)

        (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output)

        grad_input_functional = torch.nn.grad.conv3d_input(input.shape, weight, grad_output, dilation=2)
        self.assertEqual(grad_input_functional, grad_input_autograd)

        grad_weight_functional = torch.nn.grad.conv3d_weight(input, weight.shape, grad_output, dilation=2)
        self.assertEqual(grad_weight_functional, grad_weight_autograd)

    def test_functional_grad_conv2d(self):
        BATCH_SIZE = 4
        IN_CH = 8
        OUT_CH = 16
        SPATIAL = 32

        def _test_conv2d(stride, kernel_size, groups, dilation):
            padding = kernel_size // 2

            input = torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL).uniform_(-8.0, 8.0).requires_grad_(True)

            weight = torch.empty(OUT_CH, IN_CH // groups, kernel_size,
                                 kernel_size).uniform_(-4.0, 4.0).requires_grad_(True)

            output = F.conv2d(input, weight,
                              stride=stride, padding=padding, dilation=dilation, groups=groups)

            grad_output = torch.randn(output.shape)

            (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output)

            grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output,
                                                               stride=stride, padding=padding, dilation=dilation, groups=groups)
            self.assertEqual(grad_input_functional, grad_input_autograd)

            grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output,
                                                                 stride=stride, padding=padding, dilation=dilation, groups=groups)
            self.assertEqual(grad_weight_functional, grad_weight_autograd)

        strides = [1, 2]
        kernel_sizes = [1, 3, 5]
        groups = [1, 2, 4]
        dilates = [1, 2]

        for s, k, g, d in product(strides, kernel_sizes, groups, dilates):
            _test_conv2d(s, k, g, d)


class TestConvolutionNNDeviceType(NNTestCase):
    def run_conv_double_back_test(self, kern, stride, padding, chan_in, chan_out, batch_size,
                                  inp_size, dilation, no_weight, groups=1, use_cuda=False,
                                  use_bias=True, dtype=torch.double):
        """Exercise the double backward of ``F.conv2d`` for one configuration.

        Random square inputs of shape ``(batch_size, chan_in, inp_size, inp_size)``
        and weights of shape ``(chan_out, chan_in // groups, kern, kern)`` are
        created on the CPU, or on CUDA when ``use_cuda`` is set. The weight
        requires grad unless ``no_weight`` is true; a bias of shape
        ``(chan_out,)`` is used only when ``use_bias`` is true.

        Returns:
            For ``dtype == torch.float``: whether the first-order gradient
            w.r.t. the input, computed with ``create_graph=True``, itself
            requires grad. Otherwise: the result of ``gradgradcheck``.
        """
        device = torch.device("cuda" if use_cuda else "cpu")
        tensor_opts = dict(device=device, dtype=dtype)

        x = torch.randn(batch_size, chan_in, inp_size, inp_size,
                        requires_grad=True, **tensor_opts)
        weight = torch.randn(chan_out, chan_in // groups, kern, kern,
                             requires_grad=not no_weight, **tensor_opts)
        bias = torch.randn(chan_out, requires_grad=True, **tensor_opts) if use_bias else None

        def func(*inputs):
            if use_bias:
                lx, lweight, lbias = inputs
            else:
                (lx, lweight), lbias = inputs, None
            # We disable cudnn during forward to avoid finite difference imprecision issues
            with cudnn.flags(enabled=False):
                return F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)

        inputs = (x, weight, bias) if use_bias else (x, weight)

        dummy_out = func(*inputs)
        grad_y = torch.randn_like(dummy_out, device=device, dtype=dtype, requires_grad=True)

        if dtype != torch.float:
            return gradgradcheck(func, inputs, (grad_y,))

        # Issue #15353: test mkldnn double backward, don't run gradgradcheck due
        # to imprecision issues
        g, = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
        return g.requires_grad

    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(*floating_and_complex_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
    def test_Conv2d_deterministic_cudnn(self, device, dtype):
        inputs = torch.randn(2, 3, 5, 5, device=device, dtype=dtype, requires_grad=True)
        with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
            conv1 = torch.nn.Conv2d(3, 3, 3).to(device, dtype)
            conv2 = torch.nn.Conv2d(3, 3, 3).to(device, dtype)
            conv2.bias.data.copy_(conv1.bias.data)
            conv2.weight.data.copy_(conv1.weight.data)
            out1 = conv1(inputs)
            out2 = conv2(inputs)
            self.assertEqual(out1, out2, atol=0.0, rtol=0)
            y = torch.randn(out1.size(), device=device, dtype=dtype)
            out1.backward(y)
            out2.backward(y)
            self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0)
            self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0)

    @onlyCUDA
    @dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
    def test_Conv2d_large_workspace(self, device, dtype):
        # These sizes require huge cuDNN workspaces. Make sure we choose a
        # reasonable algorithm that does not run out of memory
        sizes = [
            (1, 256, 109, 175),
            (1, 256, 80, 128),
            (1, 256, 120, 192),
        ]

        def run_test(benchmark):
            with torch.backends.cudnn.flags(enabled=True, benchmark=benchmark):
                conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(device, dtype)
                for size in sizes:
                    x = torch.randn(size, device=device, dtype=dtype)
                    out = conv(x.detach().clone().requires_grad_())
                    out.backward(torch.ones_like(out))

        run_test(benchmark=False)
        run_test(benchmark=True)

    @onlyCUDA
    @dtypes(torch.half, torch.float)
    def test_ConvTranspose2d_large_output_padding(self, device, dtype):
        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
            .to(device=device, dtype=dtype)
        net2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\
            .to(device=device, dtype=dtype)
        net3 = torch.nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1)\
            .to(device=device, dtype=dtype)
        x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True)
        x = net1(x)
        x = net2(x)
        x = net3(x)
        x.backward(torch.randn_like(x))
        torch.cuda.synchronize()

    @onlyCUDA
    @dtypes(torch.float, torch.double, torch.half)
    # Very similar to test_Conv2d_naive_groups but with special care to handle
    # the number of groups == number of input channels
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
    @tf32_on_and_off(0.01)
    def test_Conv2d_depthwise_naive_groups(self, device, dtype):
        for depth_multiplier in [1, 2]:
            m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype)
            i = torch.randn(2, 2, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_()
            output = m(i)
            grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype) / 2
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())

            self.assertEqual(output, torch.cat([output1, output2], 1),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            self.assertEqual(i.grad.data,
                             torch.cat([i1.grad.data, i2.grad.data], 1),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            self.assertEqual(m.bias.grad.data,
                             torch.cat([m1.bias.grad.data,
                                        m2.bias.grad.data], 0),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            self.assertEqual(m.weight.grad.data,
                             torch.cat([m1.weight.grad.data,
                                        m2.weight.grad.data], 0),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @dtypes(torch.float, torch.double, torch.half)
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
    @tf32_on_and_off(0.01)
    def test_Conv3d_depthwise_naive_groups(self, device, dtype):
        for depth_multiplier in [1, 2]:
            m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype)
            i = torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_()
            output = m(i)
            grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype) / 2
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())
            is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6)
            atol, rtol = (3e-4, 3e-2) if dtype == torch.float32 and is_cuda_sm86 else (dtype2prec_DONTUSE[dtype], 0)

            self.assertEqual(output, torch.cat([output1, output2], 1),
                             atol=atol, rtol=rtol)
            self.assertEqual(i.grad.data,
                             torch.cat([i1.grad.data, i2.grad.data], 1),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            self.assertEqual(m.bias.grad.data,
                             torch.cat([m1.bias.grad.data,
                                        m2.bias.grad.data], 0),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            self.assertEqual(m.weight.grad.data,
                             torch.cat([m1.weight.grad.data,
                                        m2.weight.grad.data], 0),
                             atol=atol, rtol=rtol)

    @onlyCUDA
    @dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
    def test_noncontig_conv_grad(self, device, dtype):
        # FIXME: remove after adding non-contiguous grad tests for all modules
        module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
        input = torch.randn(2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True)
        output = module(input)

        grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
        assert not grad.is_contiguous()
        output.backward(grad, retain_graph=True)
        self.assertIsNotNone(input.grad)
        result = input.grad.data.clone()
        input.grad.data.zero_()

        output.backward(grad.contiguous())
        self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @dtypes(torch.double)
    def test_conv_double_backward(self, device, dtype):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            # Double backward only runs with DoubleTensor due to precision reason
            batch_size = 1
            for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
                for stride, padding, chan_in, chan_out, dilation in product([1], [2], [2], [3], dilations):
                    no_weight = stride == 2
                    result = self.run_conv_double_back_test(kern, stride,
                                                            padding, chan_in, chan_out,
                                                            batch_size, inp_size, dilation,
                                                            no_weight, use_cuda=True, dtype=dtype)
                    self.assertTrue(result,
                                    "Conv double backward test failed with parameters:" +
                                    "\nkern: " + str(kern) +
                                    "\nstride: " + str(stride) +
                                    "\npadding: " + str(padding) +
                                    "\nchan_in: " + str(chan_in) +
                                    "\nchan_out: " + str(chan_out) +
                                    "\nbatch_size: " + str(batch_size) +
                                    "\ninp_size: " + str(inp_size) +
                                    "\ndilation: " + str(dilation))

    def test_conv_double_backward_no_bias(self):
        kern = 3
        stride = 2
        chan_in, chan_out = 2, 4
        batch_size = 2
        inp_size = 5
        padding = 1
        dilation = 1
        no_weight = False
        use_bias = True
        result = self.run_conv_double_back_test(kern, stride,
                                                padding, chan_in, chan_out,
                                                batch_size, inp_size, dilation,
                                                no_weight, use_bias=use_bias)
        self.assertTrue(result,
                        "Conv double backward test failed with parameters:" +
                        "\nkern: " + str(kern) +
                        "\nstride: " + str(stride) +
                        "\npadding: " + str(padding) +
                        "\nchan_in: " + str(chan_in) +
                        "\nchan_out: " + str(chan_out) +
                        "\nbatch_size: " + str(batch_size) +
                        "\ninp_size: " + str(inp_size) +
                        "\ndilation: " + str(dilation))

    def test_conv_double_backward_groups(self):
        kern = 3
        stride = 1
        padding = 2
        chan_in, chan_out = 2, 4
        batch_size = 2
        inp_size = 6
        dilation = 1
        no_weight = False
        groups = 2
        result = self.run_conv_double_back_test(kern, stride,
                                                padding, chan_in * groups, chan_out * groups,
                                                batch_size, inp_size, dilation,
                                                no_weight, groups=groups)
        self.assertTrue(result,
                        "Conv double backward test failed with parameters:" +
                        "\nkern: " + str(kern) +
                        "\nstride: " + str(stride) +
                        "\npadding: " + str(padding) +
                        "\nchan_in: " + str(chan_in) +
                        "\nchan_out: " + str(chan_out) +
                        "\nbatch_size: " + str(batch_size) +
                        "\ninp_size: " + str(inp_size) +
                        "\ndilation: " + str(dilation) +
                        "\ngroups: " + str(groups))

    def test_conv_double_backward_stride(self):
        batch_size = 2

        # Cannot provide ggW when stride is > 1
        for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
            for stride, padding, chan_in, chan_out, dilation in product([2], [0, 1], [1], [2], dilations):
                no_weight = False
                self.run_conv_double_back_test(kern, stride,
                                               padding, chan_in, chan_out,
                                               batch_size, inp_size, dilation,
                                               no_weight)

    @dtypes(torch.float, torch.cfloat)
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
    def test_conv1d_same_padding(self, device, dtype):
        # Test padding='same' outputs the correct shape
        test_args = [
            # in_size
            range(50, 55),
            # kernel_size
            [1, 2, 3, 8],
            # dilation
            range(1, 4),
            # stride
            [1],
        ]
        for in_size, k_size, dilation, stride in itertools.product(*test_args):
            x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
            y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
            z = F.conv1d(x, y, padding='same', dilation=dilation, stride=stride)
            self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))

        # Compare F.conv1d padding='same' output against manual padding
        # Without strides/dilation
        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 3, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=1)
        actual = F.conv1d(x, y, padding='same')
        self.assertEqual(expect, actual)

        # With dilation
        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=3, dilation=2)
        actual = F.conv1d(x, y, padding='same', dilation=2)
        self.assertEqual(expect, actual)

        # Dilation with asymmetric padding
        expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
        actual = F.conv1d(x, y, padding='same', dilation=3)
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    def test_conv2d_same_padding(self, device, dtype):
        if dtype is torch.cfloat:
            rtol, atol = 2e-6, 2e-6
        else:
            rtol, atol = None, None
        # Compare F.conv2d padding='same' output against manual padding
        # Without strides/dilation
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype)
        expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
        actual = F.conv2d(x, y, padding='same')
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        # With dilation
        y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
        actual = F.conv2d(x, y, padding='same', dilation=2)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        # Dilation with asymmetric padding
        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
        actual = F.conv2d(x, y, padding='same', dilation=3)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

    @dtypes(torch.float, torch.cfloat)
    def test_conv3d_same_padding(self, device, dtype):
        if dtype is torch.cfloat:
            rtol, atol = 2e-6, 2e-6
        else:
            rtol, atol = None, None
        # Compare F.conv3d padding='same' output against manual padding
        # Without strides/dilation
        x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
        actual = F.conv3d(x, y, padding='same')
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        # With dilation
        expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        actual = F.conv3d(x, y, padding='same', dilation=2)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        # Dilation with asymmetric padding
        y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
        actual = F.conv3d(x, y, padding='same', dilation=3)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

    @dtypes(torch.float, torch.cfloat)
    def test_conv1d_valid_padding(self, device, dtype):
        # Test F.conv1d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y)
        actual = F.conv1d(x, y, padding='valid')
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    def test_conv2d_valid_padding(self, device, dtype):
        # Test F.conv2d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding='valid')
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    def test_conv3d_valid_padding(self, device, dtype):
        # Test F.conv3d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
        expect = F.conv3d(x, y)
        actual = F.conv3d(x, y, padding='valid')
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    def test_conv1d_same_padding_backward(self, device, dtype):
        # Test F.conv1d gradients work with padding='same'
        x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)

        # Symmetric padding
        z = F.conv1d(x, y, padding=3, dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding='same', dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        # Asymmetric padding
        z = F.conv1d(x, y, padding=2)[..., 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding='same')
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.float, torch.cfloat)
    @tf32_on_and_off(0.001)
    def test_conv2d_same_padding_backward(self, device, dtype):
        # padding='same' must yield the same F.conv2d gradients as the
        # equivalent explicit padding.
        def backward_grads(out, *leaves):
            # Backprop |sum(out)| and return the grads accumulated on the leaves.
            out.sum().abs().backward()
            return [leaf.grad for leaf in leaves]

        leaf_opts = dict(device=device, dtype=dtype, requires_grad=True)
        inp = torch.rand(1, 1, 10, 11, **leaf_opts)
        kernel = torch.rand(1, 1, 4, 5, **leaf_opts)

        # Symmetric case: a 4x5 kernel with dilation 2 is compared against an
        # explicit padding of (3, 4).
        want_gi, want_gk = backward_grads(
            F.conv2d(inp, kernel, padding=(3, 4), dilation=2), inp, kernel)
        inp.grad = kernel.grad = None

        backward_grads(F.conv2d(inp, kernel, padding='same', dilation=2), inp, kernel)
        self.assertEqual(want_gi, inp.grad)
        self.assertEqual(want_gk, kernel.grad)
        inp.grad = kernel.grad = None

        # Asymmetric case: a fresh 4x4 kernel; the reference pads 2 everywhere
        # and drops the first row and first column of the output.
        kernel = torch.rand(1, 1, 4, 4, **leaf_opts)
        want_gi, want_gk = backward_grads(
            F.conv2d(inp, kernel, padding=2)[..., 1:, 1:], inp, kernel)
        inp.grad = kernel.grad = None

        backward_grads(F.conv2d(inp, kernel, padding='same'), inp, kernel)
        self.assertEqual(want_gi, inp.grad)
        self.assertEqual(want_gk, kernel.grad)

    @dtypes(torch.double, torch.cdouble)
    def test_conv3d_same_padding_backward(self, device, dtype):
        # padding='same' must yield the same F.conv3d gradients as the
        # equivalent explicit padding, and must pass gradcheck/gradgradcheck.
        device_type = torch.device(device).type
        # Forward-mode AD is not checked when the device type is 'xla'.
        fwd_ad = device_type != 'xla'
        # gradgradcheck is skipped on CUDA, see
        # https://github.com/pytorch/pytorch/issues/70702
        run_gradgrad = device_type != 'cuda'

        def backward_grads(out, *leaves):
            # Backprop |sum(out)| and return the grads accumulated on the leaves.
            out.sum().abs().backward()
            return [leaf.grad for leaf in leaves]

        def same_dilated(a, b):
            return F.conv3d(a, b, padding='same', dilation=2)

        def same_plain(a, b):
            return F.conv3d(a, b, padding='same')

        leaf_opts = dict(dtype=dtype, device=device, requires_grad=True)
        inp = torch.rand(1, 1, 1, 11, 12, **leaf_opts)
        kernel = torch.rand(1, 1, 1, 2, 5, **leaf_opts)

        # Symmetric case: a 1x2x5 kernel with dilation 2 is compared against
        # an explicit padding of (0, 1, 4).
        want_gi, want_gk = backward_grads(
            F.conv3d(inp, kernel, padding=(0, 1, 4), dilation=2), inp, kernel)
        inp.grad = kernel.grad = None

        backward_grads(same_dilated(inp, kernel), inp, kernel)
        self.assertEqual(want_gi, inp.grad)
        self.assertEqual(want_gk, kernel.grad)
        inp.grad = kernel.grad = None

        gradcheck(same_dilated, (inp, kernel),
                  check_forward_ad=fwd_ad, nondet_tol=1e-5)
        if run_gradgrad:
            gradgradcheck(same_dilated, (inp, kernel), check_fwd_over_rev=True)

        # Asymmetric case: a fresh 1x4x4 kernel; the reference pads 2 in every
        # dim and drops the first row and column. padding=2 also pads the
        # depth axis, but those extra depth slices only see zero padding and
        # so add nothing to the summed output.
        kernel = torch.rand(1, 1, 1, 4, 4, **leaf_opts)
        want_gi, want_gk = backward_grads(
            F.conv3d(inp, kernel, padding=2)[..., 1:, 1:], inp, kernel)
        inp.grad = kernel.grad = None

        backward_grads(same_plain(inp, kernel), inp, kernel)
        self.assertEqual(want_gi, inp.grad)
        self.assertEqual(want_gk, kernel.grad)

        gradcheck(same_plain, (inp, kernel),
                  check_forward_ad=fwd_ad, nondet_tol=1e-5)
        if run_gradgrad:
            gradgradcheck(same_plain, (inp, kernel), check_fwd_over_rev=True)

    @dtypes(torch.float, torch.cfloat)
    def test_conv1d_valid_padding_backward(self, device, dtype):
        # F.conv1d gradients with padding='valid' must equal those with padding=0.
        leaf_opts = dict(dtype=dtype, device=device, requires_grad=True)
        inp = torch.rand(1, 1, 10, **leaf_opts)
        kernel = torch.rand(1, 1, 4, **leaf_opts)

        # Reference gradients from the explicit zero padding.
        reference_out = F.conv1d(inp, kernel, padding=0)
        reference_out.sum().abs().backward()
        ref_gi, ref_gk = inp.grad, kernel.grad
        inp.grad = kernel.grad = None

        # Gradients from the string-valued padding mode.
        valid_out = F.conv1d(inp, kernel, padding='valid')
        valid_out.sum().abs().backward()
        self.assertEqual(ref_gi, inp.grad)
        self.assertEqual(ref_gk, kernel.grad)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float, torch.cfloat)
    @parametrize_test("mode", ('valid', 'same'))
    def test_conv1d_vs_scipy(self, device, dtype, mode):
        # Cross-check F.conv1d against scipy.signal.convolve for a kernel of
        # even length and a kernel of odd length.
        signal = make_tensor((1, 10), device=device, dtype=dtype)
        signal_len = signal.shape[1]
        kernels = [make_tensor((1, 1, width), device=device, dtype=dtype)
                   for width in (4, 5)]

        def check(inp, kernel):
            # SciPy expects two 1-D inputs.
            reference = scipy.signal.convolve(
                inp.view(-1).cpu().numpy(), kernel.view(-1).cpu().numpy(), mode=mode)

            # second input is flipped in SciPy's convolve
            flipped = torch.flip(kernel, (2,))

            if mode == 'same':
                # `same` padding in PyTorch conv1d is different from SciPy,
                # so pad by hand, convolve without a padding argument and
                # crop the result back to the input length.
                half = kernel.shape[2] // 2
                padded = torch.nn.functional.pad(inp, (half, half))
                result = torch.nn.functional.conv1d(padded, flipped).squeeze(0)
                result = result[:signal_len]
            else:
                result = torch.nn.functional.conv1d(inp, flipped, padding=mode).squeeze(0)

            self.assertEqual(result, reference, atol=2e-5, rtol=2e-5)

        # Global dtype for this test suite is torch.double
        # This leads to change in type-promotion
        # and conv1d outputs `complex128` for `complex64` input.
        with set_default_dtype(torch.float):
            for kernel in kernels:
                check(signal, kernel)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float, torch.cfloat)
    @parametrize_test("mode", ('valid', 'same'))
    def test_conv2d_vs_scipy(self, device, dtype, mode):
        # Cross-check F.conv2d against scipy.signal.convolve2d for an
        # even-sized and an odd-sized kernel.
        t = make_tensor((1, 5, 10), device=device, dtype=dtype)
        # Spatial extent of the unpadded input; 'same' results are cropped
        # back to it. Captured here because `_test` rebinds `t` when padding.
        in_h, in_w = t.shape[1:]
        weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            # SciPy expects two 2-D inputs.
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve2d(t_a, w_a, mode=mode)

            kwargs = {'padding': mode}
            if mode == 'same':
                # `same` padding in PyTorch conv2d is different
                # from SciPy
                left_right_pad = weight.shape[3] // 2
                top_bottom_pad = weight.shape[2] // 2
                p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad)
                t = torch.nn.functional.pad(t, p)
                # We have already taken care of padding
                kwargs.pop("padding")

            # second input is flipped in SciPy's convolve2d
            weight_flipped = torch.flip(weight, (2, 3))
            actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == 'same':
                # Crop to the original input size (derived from the input
                # rather than hard-coded).
                actual = actual[:in_h, :in_w]

            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        # Global dtype for this test suite is torch.double
        # This leads to change in type-promotion
        # and conv2d outputs `complex128` for `complex64` input.
        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float, torch.cfloat)
    @parametrize_test("mode", ('valid', 'same'))
    def test_conv3d_vs_scipy(self, device, dtype, mode):
        # Cross-check F.conv3d against scipy.signal.convolve for two kernels
        # of different sizes.
        t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
        # Spatial extent of the unpadded input; 'same' results are cropped
        # back to it. Captured here because `_test` rebinds `t` when padding.
        in_d, in_h, in_w = t.shape[1:]
        weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            # SciPy expects two 3-D inputs.
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)

            kwargs = {'padding': mode}
            if mode == 'same':
                # `same` padding in PyTorch conv3d is different
                # from SciPy
                left_right_pad = weight.shape[4] // 2
                top_bottom_pad = weight.shape[3] // 2
                front_back_pad = weight.shape[2] // 2
                p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad,
                     front_back_pad, front_back_pad)
                t = torch.nn.functional.pad(t, p)
                # We have already taken care of padding
                kwargs.pop("padding")

            # second input is flipped in SciPy's convolve
            weight_flipped = torch.flip(weight, (2, 3, 4))
            actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == 'same':
                # Crop to the original input size (derived from the input
                # rather than hard-coded).
                actual = actual[:in_d, :in_h, :in_w]

            # Use a much looser tolerance for single-precision dtypes when
            # tf32_is_not_fp32() reports that TF32 is in effect.
            if tf32_is_not_fp32() and dtype in (torch.float, torch.complex64):
                self.assertEqual(actual, expected, atol=0.05, rtol=0.05)
            else:
                self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        # Global dtype for this test suite is torch.double
        # This leads to change in type-promotion
        # and conv3d outputs `complex128` for `complex64` input.
        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @dtypes(torch.float, torch.complex64)
    def test_conv2d_valid_padding_backward(self, device, dtype):
        # F.conv2d gradients with padding='valid' must equal those with padding=0.
        leaf_opts = dict(device=device, dtype=dtype, requires_grad=True)
        inp = torch.rand(1, 1, 1, 10, **leaf_opts)
        kernel = torch.rand(1, 1, 1, 4, **leaf_opts)

        def grads_for(pad):
            # Backprop |sum(conv)| for the given padding and return both grads.
            F.conv2d(inp, kernel, padding=pad).sum().abs().backward()
            return inp.grad, kernel.grad

        ref_gi, ref_gk = grads_for(0)
        inp.grad, kernel.grad = None, None

        got_gi, got_gk = grads_for('valid')
        self.assertEqual(ref_gi, got_gi)
        self.assertEqual(ref_gk, got_gk)

    @dtypes(torch.double, torch.cdouble)
    def test_conv3d_valid_padding_backward(self, device, dtype):
        # F.conv3d gradients with padding='valid' must equal those with
        # padding=0, and the 'valid' mode must pass gradcheck/gradgradcheck.
        # Forward-mode AD checks are turned off when the device type is 'xla'.
        fwd_ad = torch.device(device).type != 'xla'

        def valid_conv(a, b):
            return F.conv3d(a, b, padding='valid')

        leaf_opts = dict(dtype=dtype, device=device, requires_grad=True)
        inp = torch.rand(1, 1, 1, 1, 10, **leaf_opts)
        kernel = torch.rand(1, 1, 1, 1, 4, **leaf_opts)

        # Reference gradients from the explicit zero padding.
        F.conv3d(inp, kernel, padding=0).sum().abs().backward()
        ref_gi, ref_gk = inp.grad, kernel.grad
        inp.grad = kernel.grad = None

        valid_conv(inp, kernel).sum().abs().backward()
        self.assertEqual(ref_gi, inp.grad)
        self.assertEqual(ref_gk, kernel.grad)

        gradcheck(valid_conv, (inp, kernel), check_forward_ad=fwd_ad)
        gradgradcheck(valid_conv, (inp, kernel), check_fwd_over_rev=fwd_ad)

    @parametrize_test("N", range(2, 4), name_fn=lambda N: f'ConvTranspose{N}d')
    def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
        # For inputs with no batch dim, verify output is the correct shape when output_size is set.
        # See https://github.com/pytorch/pytorch/issues/75889
        if N == 2:
            input_shape = (1, 15, 13)
            output_size = (1, 240, 200)
        else:
            input_shape = (1, 15, 13, 13)
            output_size = (1, 240, 200, 200)

        unbatched = torch.randn(input_shape, device=device)
        conv_transpose_cls = getattr(nn, f'ConvTranspose{N}d')
        module = conv_transpose_cls(
            1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device)
        result = module(unbatched, output_size=output_size)
        self.assertEqual(result.shape, output_size)

    @skipMeta
    @parametrize_test("input_shape,transposed,dilated,groups,layout,backend_expected", [
        # === slow ===
        subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Slow2d),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d'),
        subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_transposed'),
        subtest(((2, 6, 7), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated2d),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_dilated'),
        subtest(((2, 6, 7), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_dilated_transposed'),
        subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Slow2d),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d'),
        subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_transposed'),
        subtest(((2, 6, 7, 8), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated2d),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_dilated'),
        subtest(((2, 6, 7, 8), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_dilated_transposed'),
        subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Slow3d),
                decorators=[onlyCPU, disableMkldnn], name='slow3d_cpu'),
        # CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead
        subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d),
                decorators=[onlyCUDA, disablecuDNN], name='slow3d_cuda'),
        # FIXME: RuntimeError: CUDA out of memory.
        # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
        #         decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'),
        subtest(((2, 6, 7, 8, 9), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated'),
        # FIXME: RuntimeError: CUDA out of memory.
        # subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
        #         decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'),
        subtest(((0, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty),
                decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch1d'),
        subtest(((2, 0, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty),
                decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel1d'),
        subtest(((0, 0, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty),
                decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel1d'),
        subtest(((0, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty),
                decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch2d'),
        subtest(((2, 0, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty),
                decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel2d'),
        subtest(((0, 0, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty),
                decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel2d'),
        subtest(((0, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty),
                decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch3d'),
        subtest(((2, 0, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty),
                decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel3d'),
        subtest(((0, 0, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty),
                decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel3d'),
        # === cuda ===
        # Note that disablecuDNN disables miopen as well.
        subtest(((2, 6, 7), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise2d),
                decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise1d'),
        subtest(((2, 6, 7, 8), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise2d),
                decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise2d'),
        subtest(((2, 6, 7, 8, 9), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise3d),
                decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise3d'),
        # === cudnn ===
        subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn1d'),
        subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn2d'),
        subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d'),
        subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn1d_transposed'),
        subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn2d_transposed'),
        # FIXME: RuntimeError: CUDA out of memory.
        # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose),
        #         decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'),
        # === miopen ===
        subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen1d'),
        subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen2d'),
        subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen3d'),
        subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen1d_transposed'),
        subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen2d_transposed'),
        subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen3d_transposed'),
        subtest(((2, 6, 7), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise1d'),
        subtest(((2, 6, 7, 8), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise2d'),
        subtest(((2, 6, 7, 8, 9), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise3d'),
        # === mkldnn ===
        subtest(((2, 6, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn1d'),
        subtest(((2, 6, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn2d'),
        subtest(((2, 6, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn3d'),
        # Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775.
        subtest(((2, 6, 7), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn),
                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn1d_transposed'),
        subtest(((2, 6, 7, 8), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn),
                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn2d_transposed'),
        subtest(((2, 6, 7, 8, 9), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn),
                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn3d_transposed'),
        subtest(((2, 6, 7), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn1d_cpu_input'),
        subtest(((2, 6, 7, 8), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn2d_cpu_input'),
        subtest(((2, 6, 7, 8, 9), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn3d_cpu_input'),
        subtest(((0, 6, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch1d'),
        subtest(((2, 0, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel1d'),
        subtest(((0, 0, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel1d'),
        subtest(((0, 6, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch2d'),
        subtest(((2, 0, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel2d'),
        subtest(((0, 0, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel2d'),
        subtest(((0, 6, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch3d'),
        subtest(((2, 0, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel3d'),
        subtest(((0, 0, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty),
                decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel3d'),
        # Note: Tests for mobile backends are not currently supported. This comprises
        # NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these
        # requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1.
    ])
    # Test with both bias and no bias.
    @parametrize_test("has_bias", [False, True])
    # Test with both stride=1 and stride>1 cases.
    @parametrize_test("strided", [False, True])
    # Test with both contiguous and non-contiguous inputs.
    @parametrize_test("contiguous", [False, True])
    def test_conv_backend(
            self, device, input_shape, has_bias, strided, contiguous, transposed, dilated, groups,
            layout, backend_expected):
        # Verifies that convolution dispatch selects `backend_expected` for the
        # given configuration, then exercises backward, forward AD and
        # gradcheck / gradgradcheck on it.
        dtype = torch.float32
        C_out, kernel_size = 12, 3
        C_in = input_shape[1]
        dim = len(input_shape) - 2
        leaf_opts = dict(device=device, dtype=dtype, requires_grad=True)

        # Build up inputs (x, then weight, then bias).
        x = torch.randn(*input_shape, **leaf_opts)
        if transposed:
            weight_lead = (C_in, C_out // groups)
        else:
            weight_lead = (C_out, C_in // groups)
        weight = torch.randn(*weight_lead, *([kernel_size] * dim), **leaf_opts)
        bias = torch.randn(C_out, **leaf_opts) if has_bias else None

        def _strided_copy(tensor):
            # Duplicate every element along the last dim and keep every other
            # one: same values, but the result is a stride-2 slice. The
            # original requires_grad flag is carried over to the new leaf.
            if tensor is None:
                return None
            doubled = torch.repeat_interleave(tensor, 2, dim=-1)
            return doubled[..., ::2].detach().requires_grad_(tensor.requires_grad)

        if not contiguous:
            x, weight, bias = (_strided_copy(t) for t in (x, weight, bias))

        uses_mkldnn_layout = layout is torch._mkldnn
        if uses_mkldnn_layout:
            # Note that weight and bias are not supported as mkldnn tensors during training.
            x = x.to_mkldnn()

        def _per_dim(value):
            # One entry per spatial dimension.
            return (value,) * dim

        stride = _per_dim(2 if strided else 1)
        padding = _per_dim(0)
        dilation = _per_dim(2 if dilated else 1)
        output_padding = _per_dim(0)
        inputs = [x, weight, bias, stride, padding, dilation, transposed, output_padding, groups]

        # Ensure correct backend is selected.
        backend_actual = torch._C._select_conv_backend(*inputs)
        self.assertEqual(backend_actual, backend_expected)

        # Ensure backward call succeeds.
        convolution = torch.ops.aten.convolution
        output = convolution(*inputs)
        grad_output = torch.randn(output.shape, device=device, dtype=dtype)
        if not contiguous:
            grad_output = _strided_copy(grad_output)
        if uses_mkldnn_layout:
            grad_output = grad_output.to_mkldnn()
        output.backward(grad_output)

        # mkldnn doesn't support gradcheck :(
        if uses_mkldnn_layout:
            return

        if backend_actual != torch._C._ConvBackend.Empty:  # FIXME: forward AD fails
            # Forward AD and forward-over-reverse AD smoke test in float32
            # TODO: remove this if we introduce per-op gradient tests for float32
            with fwAD.dual_level():
                # Wrap every tensor argument in a dual tensor with a random
                # tangent; non-tensor arguments pass through untouched.
                dual_inputs = []
                for arg in inputs:
                    if isinstance(arg, torch.Tensor):
                        arg = fwAD.make_dual(arg, torch.rand_like(arg))
                    dual_inputs.append(arg)
                # Forward AD
                output = convolution(*dual_inputs)
                # Forward over reverse AD
                grad_primal = torch.rand_like(output)
                grad_tangent = torch.rand_like(output)
                grad_output_d = fwAD.make_dual(grad_primal, grad_tangent)
                wrt = [x, weight, bias] if has_bias else [x, weight]
                torch.autograd.grad(output, wrt, grad_output_d)

        def _as_double_leaf(tensor):
            # Fresh float64 leaf that requires grad.
            return tensor.to(torch.float64).detach().requires_grad_(True)

        # Convert to float64 for gradcheck.
        x = _as_double_leaf(x)
        weight = _as_double_leaf(weight)
        if bias is not None:
            bias = _as_double_leaf(bias)
        inputs = [x, weight, bias, stride, padding, dilation, transposed, output_padding, groups]

        # Set some backend-specific validation settings.
        # cuDNN introduces non-determinism
        if torch.backends.cudnn.is_available():
            gradcheck_nondet_tol = GRADCHECK_NONDET_TOL
        else:
            gradcheck_nondet_tol = 0.0

        self.assertTrue(gradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol))

        # double backward doesn't support bias gradients
        # (`inputs` holds this same bias object, so the in-place flag change
        # is seen by the call below)
        if bias is not None:
            bias.requires_grad_(False)
        self.assertTrue(gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol))

    @onlyCPU
    def test_conv_contiguous_for_oneDNN(self):
        # See https://github.com/pytorch/pytorch/issues/80837.
        # Runs Conv2d on a view obtained by transposing and slicing a 5-D
        # tensor, and checks that the output with MKLDNN enabled matches the
        # output with MKLDNN explicitly disabled. Nothing is compared when
        # MKLDNN is not available.
        for dtype in [torch.float, torch.bfloat16, torch.half]:
            conv = nn.Conv2d(
                1,
                128,
                kernel_size=(5, 2),
                stride=(2, 1),
                padding=(0, 1),
                dilation=(1, 1),
                groups=1,
                bias=True,
                padding_mode='zeros').to(dtype=dtype)

            x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype)
            # Swap dims 1 and 4, then take index 0 of the last dim to get a
            # 4-D view of shape (1, 1, 321, 201).
            x = torch.transpose(x, 1, 4)
            x2 = x[..., 0]
            if torch.backends.mkldnn.is_available():
                y = conv(x2)
                # Disable MKLDNN explicitly
                with torch.backends.mkldnn.flags(enabled=False):
                    y_ = conv(x2)
                    self.assertEqual(y, y_)

    @onlyCPU
    def test_conv_ic1_channels_last_for_oneDNN(self):
        # See https://github.com/pytorch/pytorch/issues/82060, N > 1 will call in OneDNN path.
        # A single-input-channel Conv2d converted to channels_last must give
        # the same output with MKLDNN enabled and with it disabled.
        for dtype in (torch.float, torch.bfloat16, torch.half):
            module = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False)
            module = module.to(memory_format=torch.channels_last).to(dtype=dtype)
            batch = torch.rand(2, 1, 100, 100).to(dtype=dtype)
            if not torch.backends.mkldnn.is_available():
                continue
            with_mkldnn = module(batch)
            # Disable MKLDNN explicitly
            with torch.backends.mkldnn.flags(enabled=False):
                without_mkldnn = module(batch)
                self.assertEqual(with_mkldnn, without_mkldnn)

    @dtypes(torch.float, torch.cfloat)
    def test_conv_empty_channel(self, device, dtype):
        # For Conv1d/2d/3d built with zero input channels: an input with zero
        # channels goes through the shared empty-input check, while an input
        # with one channel must raise the channel-mismatch RuntimeError.
        in_channels = 0
        # (module class, out_channels, kernel_size, zero-channel input shape,
        #  one-channel input shape expected to be rejected)
        cases = [
            (torch.nn.Conv1d, 8, 2, (2, 0, 15), (2, 1, 0)),
            (torch.nn.Conv2d, 33, 3, (2, 0, 50, 100), (2, 1, 40, 0)),
            (torch.nn.Conv3d, 33, 3, (2, 0, 50, 20, 40), (2, 1, 50, 0, 40)),
        ]
        for conv_cls, out_channels, kernel_size, empty_shape, rejected_shape in cases:
            mod = conv_cls(in_channels, out_channels, kernel_size, stride=2, dtype=dtype).to(device)
            inp = torch.randn(*empty_shape, device=device, dtype=dtype)
            _test_module_empty_input(self, mod, inp, check_size=False)

            with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
                inp = torch.randn(*rejected_shape, device=device, dtype=dtype)
                mod(inp)

    def test_group_conv_empty(self, device):
        mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device)
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)
        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                _test_module_empty_input(self, mod, inp, check_size=False)

    def test_group_convTranspose_empty(self, device):
        mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device)
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)
        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                _test_module_empty_input(self, mod, inp, check_size=False)

    def test_convTranspose_empty(self, device):
        # Zero-sample batch through an ungrouped ConvTranspose2d, verified by the
        # shared _test_module_empty_input helper.
        deconv = torch.nn.ConvTranspose2d(4, 4, kernel_size=3, stride=2, padding=1)
        deconv = deconv.to(device)
        zero_batch = torch.randn(0, 4, 4, 4, device=device)

        def check():
            _test_module_empty_input(self, deconv, zero_batch, check_size=False)

        check()
        # A CUDA run with cuDNN available repeats the check with cuDNN disabled.
        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                check()

    @onlyCUDA
    @largeTensorTest('12GB')
    def test_conv_large_nosplit(self, device):
        # Smoke test only: the point is that these convolutions get routed to the
        # fallback implementation without crashing. Numerical correctness of the
        # fallback is expected to be covered by other tests, so no output is checked.
        if self.device_type == 'cuda':
            dtype = torch.half
        else:
            dtype = torch.float

        # Case 1: 8x8 kernel, stride 8, over a 1024 x (1024 * 1024) two-channel input.
        strided_conv = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype)
        input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
        strided_conv(input_large)

        # Case 2: 1x1 kernel expanding one channel to 1024. The same name is rebound
        # on purpose so the first input is not kept alive any longer than before.
        pointwise_conv = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
        input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
        pointwise_conv(input_large)

    def test_conv_noncontig_weights(self, device):
        for dim in (1, 2, 3):
            for grouped in (False, True):
                nc = 3
                groups = 3 if grouped else 1
                w = torch.randn([3] * dim, device=device)
                w = w.expand([nc, int(nc / groups)] + list(w.shape))
                w = w.detach().requires_grad_()
                x = torch.randn([1, nc] + ([5] * dim), device=device, requires_grad=True)
                y = getattr(F, f'conv{dim}d')(x, w, groups=groups)
                y.sum().backward()
                y = getattr(F, f'conv_transpose{dim}d')(x, w, groups=groups)
                y.sum().backward()

    def test_conv_noncontig_weights_and_bias(self, device):
        # need floats to exercise https://github.com/pytorch/pytorch/issues/16018
        #
        # Input, weight and (optionally) bias are each created with an extra trailing
        # dimension of size 2 and then sliced at index 1 of it, which yields strided
        # views. The convolution on those views must match the convolution on their
        # .contiguous() copies.
        for with_bias in (True, False):
            conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=with_bias)
            conv = conv.to(device, torch.float)

            strided_input = torch.randn((1, 3, 224, 224, 2), device=device, dtype=torch.float)[..., 1]
            packed_input = strided_input.contiguous()

            conv.weight = nn.Parameter(
                torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[..., 1])
            packed_weight = conv.weight.contiguous()

            if with_bias:
                conv.bias = nn.Parameter(
                    torch.randn((64, 2), device=device, dtype=torch.float)[..., 1])
                packed_bias = conv.bias.contiguous()

            # First pass: every operand is a strided view.
            out_strided = conv(strided_input)

            # Second pass: same values, all operands contiguous.
            conv.weight = nn.Parameter(packed_weight)
            if with_bias:
                conv.bias = nn.Parameter(packed_bias)
            out_packed = conv(packed_input)

            self.assertEqual(out_strided, out_packed)

    @onlyCUDA
    @largeTensorTest('12GB')
    @skipIfRocmVersionLessThan((6, 0))
    def test_conv_transposed_large(self, device):
        # Runs a 1x1 ConvTranspose2d on a 4096-sample batch and compares each
        # quarter of the result (1024 samples) with the output obtained by feeding
        # that quarter through the layer on its own.
        on_cuda = self.device_type == 'cuda'
        dtype = torch.half if on_cuda else torch.float
        conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
        input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)
        # forward
        ret = conv(input_large)
        # Largest absolute deviation per quarter. Kept as a single expression so no
        # intermediate tensor stays referenced between quarters.
        max_diffs = [
            (ret.narrow(0, start, 1024) - conv(input_large.narrow(0, start, 1024))).abs_().max().item()
            for start in (0, 1024, 2048, 3072)
        ]
        # cuDNN may use algorithms such as FFT that don't guarantee a diff of 0,
        # so CUDA gets explicit tolerances; otherwise the defaults apply.
        tolerances = {'atol': 2e-3, 'rtol': 1e-5} if on_cuda else {}
        for max_diff in max_diffs:
            self.assertEqual(max_diff, 0, **tolerances)

    @onlyCUDA
    @skipCUDAIfRocm
    @largeTensorTest('12GB')
    def test_conv_large(self, device):
        # Runs an 8x8, stride-8 convolution on a 4097-sample batch and checks that
        # both the forward output and the accumulated weight gradient agree with
        # processing the batch in pieces of 2048, 2048 and 1 samples.
        dtype = torch.half if self.device_type == 'cuda' else torch.float
        conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
        input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
        # (batch slice, number of samples in that slice)
        pieces = (
            (slice(None, 2048), 2048),
            (slice(2048, 4096), 2048),
            (slice(4096, None), 1),
        )

        # forward
        ret = conv(input_large)
        for rows, _ in pieces:
            self.assertEqual(ret[rows], conv(input_large[rows]))

        # backward
        conv.zero_grad()
        # The loss goes through `max(dim=1)` to make the gradient sparse. Without
        # that sparsity the rounding error would be too large (as large as 1e-5)
        # to satisfy the criterion (1e-6) of `assertEqual`.
        ret.view(4097, -1).max(dim=1).values.sum().backward()
        del ret
        grad_whole = conv.weight.grad.detach().clone()
        conv.zero_grad()
        # Each piece's backward accumulates into conv.weight.grad.
        for rows, count in pieces:
            conv(input_large[rows]).view(count, -1).max(dim=1).values.sum().backward()
        grad_pieces = conv.weight.grad.detach().clone()
        # gradients are at the order of hundreds, so both are rescaled to the
        # order of one before being compared
        scale = 1 / grad_pieces.abs().mean()
        self.assertEqual(grad_whole * scale, grad_pieces * scale, atol=5e-2, rtol=5e-3)

    @onlyCUDA
    @skipCUDAIfNoCudnn
    def test_contig_wrong_stride_cudnn(self, device):
        # Builds a tensor that reports is_contiguous() despite an unusual batch
        # stride and checks that conv_transpose2d and conv2d accept it without raising.
        # x has to have batch_size 1 to test contiguous checks
        x = torch.randn(1, 16, 5, 5, device=device)
        # Replace only the dimension-0 stride with 20. The tensor is still
        # contiguous because size[0] is 1.
        odd_strides = [20, *x.stride()[1:]]
        x.set_(x.storage(), 0, x.size(), odd_strides)
        self.assertTrue(x.is_contiguous())

        deconv_weight = torch.randn(16, 1, 1, 1, device=device)
        F.conv_transpose2d(x, deconv_weight)
        conv_weight = torch.randn(1, 16, 1, 1, device=device)
        F.conv2d(x, conv_weight)

    @onlyCUDA
    def test_Conv2d_size_1_kernel(self, device):
        # Compares a kernel_size=1 Conv2d on CPU against the same layer on `device`
        # with cuDNN disabled: outputs and bias/weight gradients must agree.
        tol = dict(atol=1e-5, rtol=0, exact_device=False)

        cpu_in = torch.randn(2, 3, 5, 5)
        cpu_conv = torch.nn.Conv2d(3, 3, kernel_size=1)
        cpu_out = cpu_conv(cpu_in)
        upstream = torch.rand_like(cpu_out)
        cpu_out.backward(upstream)

        with cudnn.flags(enabled=False):
            dev_conv = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
            # Copy the CPU parameters so both layers compute the same function.
            for name in ('bias', 'weight'):
                getattr(dev_conv, name).data.copy_(getattr(cpu_conv, name).data)
            dev_out = dev_conv(cpu_in.to(device))
            dev_out.backward(upstream.to(device))

        self.assertEqual(cpu_out, dev_out, **tol)
        for name in ('bias', 'weight'):
            self.assertEqual(getattr(cpu_conv, name).grad.data,
                             getattr(dev_conv, name).grad.data, **tol)

    @onlyCUDA
    def test_ConvTranspose2d_size_1_kernel(self, device):
        # Compares a kernel_size=1 ConvTranspose2d on CPU against the same layer on
        # `device` with cuDNN disabled: outputs and parameter gradients must agree.
        def make_layer():
            return torch.nn.ConvTranspose2d(3, 3, kernel_size=1)

        host_x = torch.randn(2, 3, 5, 5)
        host_layer = make_layer()
        host_y = host_layer(host_x)
        grad_seed = torch.rand_like(host_y)
        host_y.backward(grad_seed)

        with cudnn.flags(enabled=False):
            dev_layer = make_layer().to(device)
            # Overwrite the fresh parameters with the CPU layer's values.
            dev_layer.bias.data.copy_(host_layer.bias.data)
            dev_layer.weight.data.copy_(host_layer.weight.data)
            dev_y = dev_layer(host_x.to(device))
            dev_y.backward(grad_seed.to(device))

        compared = (
            (host_y, dev_y),
            (host_layer.bias.grad.data, dev_layer.bias.grad.data),
            (host_layer.weight.grad.data, dev_layer.weight.grad.data),
        )
        for expected, actual in compared:
            self.assertEqual(expected, actual, atol=1e-5, rtol=0, exact_device=False)

    @onlyCUDA
    def test_ConvTranspose3d_size_1_kernel(self, device):
        # Compares a kernel_size=1 ConvTranspose3d on CPU against the same layer on
        # `device` with cuDNN disabled. Everything runs with double as the default
        # dtype.
        with set_default_dtype(torch.double):
            same = dict(atol=1e-5, rtol=0, exact_device=False)

            ref_x = torch.randn(2, 3, 3, 5, 5)
            ref_layer = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
            ref_y = ref_layer(ref_x)
            grad_out = torch.rand_like(ref_y)
            ref_y.backward(grad_out)

            with cudnn.flags(enabled=False):
                layer = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
                # Share the reference parameters with the device layer.
                layer.bias.data.copy_(ref_layer.bias.data)
                layer.weight.data.copy_(ref_layer.weight.data)
                y = layer(ref_x.to(device))
                y.backward(grad_out.to(device))

            self.assertEqual(ref_y, y, **same)
            self.assertEqual(ref_layer.bias.grad.data, layer.bias.grad.data, **same)
            self.assertEqual(ref_layer.weight.grad.data, layer.weight.grad.data, **same)

    @dtypesIfCUDA(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
    @dtypes(torch.float)
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
    def test_Conv2d_naive_groups(self, device, dtype):
        # Check that a grouped convolution (groups=2) matches two independent
        # convolutions, each applied to one half of the channels: same output,
        # same input gradient and same bias/weight gradients.
        m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
        i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
        output.backward(grad_output)

        # One stand-alone conv per channel half, seeded with the matching slice
        # of the grouped layer's parameters, input and upstream gradient.
        halves = []
        for chans in (slice(None, 2), slice(2, None)):
            sub = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
            sub.weight.data.copy_(m.weight.data[chans])
            sub.bias.data.copy_(m.bias.data[chans])
            sub_in = i.data[:, chans].contiguous().requires_grad_(True)
            sub_out = sub(sub_in)
            sub_out.backward(grad_output[:, chans].contiguous())
            halves.append((sub, sub_in, sub_out))
        (m1, i1, output1), (m2, i2, output2) = halves

        self.assertEqual(output, torch.cat([output1, output2], 1))

        tol = dtype2prec_DONTUSE[dtype]
        self.assertEqual(i.grad.data,
                         torch.cat([i1.grad.data, i2.grad.data], 1),
                         atol=tol, rtol=0)
        self.assertEqual(m.bias.grad.data,
                         torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                         atol=tol, rtol=0)
        self.assertEqual(m.weight.grad.data,
                         torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                         atol=tol, rtol=0)

    @dtypes(torch.double, torch.cdouble)
    def test_Conv2d_backward_depthwise(self, device, dtype):
        x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
        weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)

        def conv2d_depthwise(x, weight):
            return torch.nn.functional.conv2d(
                x, weight, bias=None, stride=(1, 10), groups=2)

        for cudnn_enabled in [False, True]:
            with torch.backends.cudnn.flags(enabled=cudnn_enabled):
                torch.autograd.gradcheck(conv2d_depthwise, (x, weight))

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_conv_thnn_nhwc(self, device, dtype):
        # With mkldnn disabled, compares convolutions whose input and/or weight are
        # in channels_last layout against a reference that uses contiguous layout
        # throughout: forward output and all gradients must match, and the output
        # of the tested layer must be channels_last.
        def helper(mod, n, c, h, w, out_channels, kernel_size, dilation, groups, input_format, weight_format):
            conv_kwargs = dict(dilation=dilation, groups=groups)

            # Input and parameters hold integer values drawn from [-3, 3).
            x = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)
            x = x.to(memory_format=input_format)
            x.requires_grad_()
            conv = mod(c, out_channels, kernel_size, **conv_kwargs)
            conv = conv.to(device='cpu', dtype=dtype, memory_format=weight_format)
            for param in conv.parameters():
                param.data = torch.randint_like(param, -3, 3)

            ref_x = x.detach().clone().contiguous().requires_grad_()
            ref_conv = mod(c, out_channels, kernel_size, **conv_kwargs)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(device='cpu', dtype=dtype, memory_format=torch.contiguous_format)

            out = conv(x)
            ref_out = ref_conv(ref_x)

            grad = torch.randint_like(out, -3, 3)
            ref_grad = grad.detach().clone().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            compared = (
                (out, ref_out),
                (conv.weight.grad, ref_conv.weight.grad),
                (conv.bias.grad, ref_conv.bias.grad),
                (x.grad, ref_x.grad),
            )
            for actual, expected in compared:
                self.assertEqual(actual, expected, exact_dtype=False)

        CL = torch.channels_last
        CF = torch.contiguous_format
        with torch.backends.mkldnn.flags(enabled=False):
            formats = [[CL, CL], [CL, CF], [CF, CL]]
            for input_format, weight_format in formats:
                # Columns: module, (n, c, h, w), out_channels, kernel_size,
                # dilation, groups, input format, weight format.
                cases = (
                    # non-dilated conv: thnn_conv2d normal path (with im2col)
                    (nn.Conv2d, (2, 8, 4, 4), 4, 3, 1, 1, input_format, weight_format),
                    (nn.Conv2d, (2, 8, 4, 4), 8, 3, 1, 8, input_format, weight_format),
                    # test when input channels is 1 and not converted to channels last
                    (nn.Conv2d, (2, 1, 10, 10), 8, 3, 1, 1, CF, CL),
                    # non-dilated conv: thnn_conv2d fast path (skip im2col)
                    (nn.Conv2d, (1, 16, 56, 56), 16, 1, 1, 1, input_format, weight_format),
                    # ic == oc == 1 here, so need to stick input to CL to activate channels last
                    (nn.Conv2d, (1, 16, 56, 56), 16, 1, 1, 16, CL, weight_format),
                    # dilated conv: slow_conv_dilated2d
                    (nn.Conv2d, (2, 8, 11, 13), 16, 3, 2, 1, input_format, weight_format),
                    (nn.Conv2d, (2, 16, 11, 13), 32, 3, 2, 16, input_format, weight_format),
                    # transposed-conv: slow_conv_transpose2d
                    (nn.ConvTranspose2d, (2, 8, 4, 4), 4, 3, 1, 1, input_format, weight_format),
                    (nn.ConvTranspose2d, (2, 8, 4, 4), 8, 3, 1, 8, input_format, weight_format),
                    (nn.ConvTranspose2d, (1, 16, 56, 56), 16, 1, 1, 1, input_format, weight_format),
                    (nn.ConvTranspose2d, (1, 16, 56, 56), 32, 1, 1, 16, input_format, weight_format),
                )
                for mod, shape, oc, ksize, dil, grp, in_fmt, w_fmt in cases:
                    helper(mod, *shape, out_channels=oc, kernel_size=ksize, dilation=dil,
                           groups=grp, input_format=in_fmt, weight_format=w_fmt)

    @onlyCUDA
    @skipCUDAIfRocmVersionLessThan((4, 3))
    @skipCUDAIfNotMiopenSuggestNHWC
    @skipCUDAIfCudnnVersionLessThan(7603)
    @dtypes(torch.half, torch.float, torch.cfloat)
    def test_conv_cudnn_nhwc(self, device, dtype):
        # Compares a channels_last Conv2d in `dtype` against a contiguous
        # (channels-first) FP64 reference: output and gradients must match in value,
        # and the tested tensors must come out in channels_last layout.
        def helper(n, c, h, w, out_channels, kernel_size, groups):
            # Input and parameters hold integer values drawn from [-3, 3).
            x = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)\
                .to(memory_format=torch.channels_last)
            x.requires_grad_()
            # The modules are placed on the `device` passed to the test (not a
            # hard-coded 'cuda') so they always live where the input lives.
            conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)\
                .to(device=device, dtype=dtype, memory_format=torch.channels_last)
            for p in conv.parameters():
                p.data = torch.randint_like(p, -3, 3)

            # use FP64 channels-first conv as reference
            ref_x = x.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(device=device, dtype=torch.double, memory_format=torch.contiguous_format)

            out = conv(x)
            ref_out = ref_conv(ref_x)

            grad = torch.randint_like(out, -3, 3)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            # Layout checks: tested path is channels_last, reference is contiguous.
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(x.grad.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(conv.weight.grad.is_contiguous(memory_format=torch.channels_last))

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_x.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            # Value checks; dtypes differ by construction, hence exact_dtype=False.
            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(x.grad, ref_x.grad, exact_dtype=False)

        helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

    @onlyCUDA
    @skipCUDAIfRocm
    @skipCUDAIfCudnnVersionLessThan(8005)
    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_ndhwc(self, device, dtype):
        # Compares a channels_last_3d Conv3d in `dtype` against a contiguous
        # (channels-first) FP64 reference: output and gradients must match in value,
        # and the tested tensors must come out in channels_last_3d layout.
        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
            # Input and parameters hold integer values drawn from [-2, 2).
            x = torch.randint(-2, 2, (n, c, d, h, w), dtype=dtype, device=device)\
                .to(memory_format=torch.channels_last_3d)
            x.requires_grad_()
            # The modules are placed on the `device` passed to the test (not a
            # hard-coded 'cuda') so they always live where the input lives.
            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)\
                .to(device=device, dtype=dtype, memory_format=torch.channels_last_3d)
            for p in conv.parameters():
                p.data = torch.randint_like(p, -2, 2)

            # use FP64 channels-first conv as reference
            ref_x = x.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(device=device, dtype=torch.double, memory_format=torch.contiguous_format)

            out = conv(x)
            ref_out = ref_conv(ref_x)

            grad = torch.randint_like(out, -2, 2)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            # Layout checks: tested path is channels_last_3d, reference is contiguous.
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(x.grad.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d))

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_x.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            # Value checks; dtypes differ by construction, hence exact_dtype=False.
            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(x.grad, ref_x.grad, exact_dtype=False)

        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)

    def _run_conv(self, layer, device, inp, grad, ref_conv, ref_input, ref_out,
                  input_format, weight_format, grad_format, output_format):
        # Builds a float `layer` with the reference parameters, puts weight, input
        # and upstream gradient into the requested memory formats, runs forward and
        # backward, then checks the output layout against `output_format` and all
        # values against the reference results.
        in_channels, out_channels = inp.size(1), grad.size(1)
        kernel_size = ref_conv.weight.size(2)
        conv = layer(in_channels, out_channels, kernel_size).float().to(device)
        # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
        conv.load_state_dict(ref_conv.state_dict())

        # Each tensor gets contiguous(memory_format=...) followed by an in-place
        # resize_ to its own size with the same memory_format.
        # NOTE(review): presumably the resize_ pins the strides for shapes where
        # the layout is ambiguous (size-1 dims) -- confirm.
        weight = conv.weight.detach().clone().contiguous(memory_format=weight_format)
        weight.resize_(weight.size(), memory_format=weight_format)
        conv.weight.data = weight

        x = inp.clone().contiguous(memory_format=input_format)
        x.resize_(x.size(), memory_format=input_format)
        x.requires_grad_()

        upstream = grad.contiguous(memory_format=grad_format)
        upstream.resize_(upstream.size(), memory_format=grad_format)

        out = conv(x)
        out.backward(upstream)

        self.assertTrue(out.is_contiguous(memory_format=output_format))
        compared = (
            (out, ref_out),
            (conv.weight.grad, ref_conv.weight.grad),
            (conv.bias.grad, ref_conv.bias.grad),
            (x.grad, ref_input.grad),
        )
        for actual, expected in compared:
            self.assertEqual(actual, expected)

    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
        # Computes a reference forward/backward for `layer(c, k, filter_size)` and
        # then replays it through _run_conv for every combination of contiguous /
        # channels_last layout on the weight, the upstream gradient and the input,
        # passing along the output layout expected for that combination.
        #
        # layer: nn.Conv2d or nn.ConvTranspose2d (the two classes handled below).
        # n, c, h, w: shape of the input; k: number of output channels.
        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
        ref_input = data.clone().contiguous().requires_grad_(True)
        ref_conv = layer(c, k, filter_size).float().to(device)
        ref_out = ref_conv(ref_input)
        # Created on `device` (not a hard-coded "cuda") so the gradient always
        # lives on the same device as the output it is fed into.
        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device=device)
        ref_out.backward(grad)

        formats = (torch.contiguous_format, torch.channels_last)
        # Older versions of CudNN have Channels Last support disabled
        channels_last_enabled = torch.backends.cudnn.version() >= 7603
        for w_f, g_f, input_format in product(formats, formats, formats):
            output_format = torch.contiguous_format
            if channels_last_enabled:
                if input_format == torch.channels_last:
                    output_format = torch.channels_last
                # This is because we have N111 weight that cannot handle
                # the ambiguous memory_format
                if w_f == torch.channels_last:
                    if layer == nn.Conv2d and filter_size * c != 1:
                        output_format = torch.channels_last
                    if layer == nn.ConvTranspose2d and filter_size * k != 1:
                        output_format = torch.channels_last
            self._run_conv(layer, device, data, grad, ref_conv, ref_input,
                           ref_out, input_format, w_f, g_f, output_format)

    @onlyCUDA
    @skipCUDAIfRocmVersionLessThan((4, 3))
    @skipCUDAIfNotMiopenSuggestNHWC
    @skipCUDAIfCudnnVersionLessThan(7603)
    @tf32_on_and_off(0.05)
    def test_conv_cudnn_mismatch_memory_format(self, device):
        # Runs the mixed-layout comparison of _test_conv_cudnn_nhwc_nchw for both
        # Conv2d and ConvTranspose2d over a set of small shapes.
        # Each row is (n, c, h, w, k, filter_size).
        configs = (
            (4, 2, 8, 8, 4, 2),
            (4, 1, 8, 8, 4, 2),
            (1, 1, 8, 8, 4, 2),
            (4, 2, 2, 8, 4, 1),
            (4, 2, 1, 8, 4, 1),
            (4, 2, 8, 8, 4, 1),
            (4, 1, 8, 8, 4, 1),
        )
        for n, c, h, w, k, filter_size in configs:
            for layer in (nn.Conv2d, nn.ConvTranspose2d):
                self._test_conv_cudnn_nhwc_nchw(layer, n, c, h, w, k, filter_size, device)

    # torch.half is erroring out on Windows with CUDA 10.1 + cuDNN 7.6.4
    # returning CUDNN_STATUS_BAD_PARAM
    # Disabling that specific test for now [see issue # 33918]
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.double)
    def test_conv_cudnn_nhwc_support(self, device, dtype):
        # conv2d with a channels_last weight and a 1x1 spatial input must produce a
        # channels_last output, and the backward pass must run without raising.
        # Both tensors are created on the `device` passed to the test rather than
        # on a hard-coded "cuda".
        x = torch.randn((1, 16, 1, 1), dtype=dtype, device=device, requires_grad=True)
        weight = torch.randn((8, 16, 3, 3), dtype=dtype, device=device, requires_grad=True)
        weight = weight.to(memory_format=torch.channels_last)
        o = torch.conv2d(x, weight, bias=None, stride=(2, 1), padding=(1, 1),
                         dilation=(1, 1), groups=1)
        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
        o.sum().backward()

    # Test that faster algorithms used for inference produce the same results
    # Validates depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176
    @onlyCPU
    @dtypes(torch.float)
    def test_conv2d_no_grad(self, device, dtype):
        # For several batch sizes and group counts, a grouped 3x3 Conv2d must give
        # the same output under torch.no_grad() as with gradient tracking on.
        for batch, groups in product((1, 2, 3), (1, 2, 4)):
            x = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
            # in_channels == groups, so every group sees exactly one input channel.
            conv = nn.Conv2d(groups, 8, kernel_size=(3, 3), groups=groups, dtype=dtype, device=device)
            with torch.no_grad():
                out_no_grad = conv(x)
            out_with_grad = conv(x)
            self.assertEqual(out_with_grad, out_no_grad, rtol=1e-2, atol=1e-5)

    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_relu(self, device, dtype):
        # The fused convolution+relu op (MIOpen variant on HIP builds, cuDNN variant
        # otherwise) must match torch.conv2d followed by relu, and must return its
        # result in the memory format of its inputs.
        batches = (1, 2, 3)
        channel_counts = (1, 2, 4)
        image_sizes = ((1, 1), (8, 8))
        kernel_sizes = ((1, 1), (3, 3))
        layouts = (torch.channels_last, torch.contiguous_format)
        for batch, groups, image_size, kernel_size, memory_format in product(
                batches, channel_counts, image_sizes, kernel_sizes, layouts):
            # Skip combinations where the kernel is taller than the image.
            if image_size[0] < kernel_size[0]:
                continue
            # `groups` is used as the channel count; the convolution itself is
            # called with groups=1 (last positional argument).
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1)

            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            fused_op = torch.miopen_convolution_relu if torch.version.hip \
                else torch.cudnn_convolution_relu
            cudnn_out = fused_op(inp, w, None, (1, 1), (0, 0), (1, 1), 1)
            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))

            # float32 with TF32 active gets explicit, looser tolerances.
            if tf32_is_not_fp32() and dtype == torch.float:
                tolerances = {'atol': 4e-3, 'rtol': 0.006}
            else:
                tolerances = {}
            self.assertEqual(conv2d_out.relu(), cudnn_out, **tolerances)

    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_add_relu(self, device, dtype):
        # The fused convolution+add+relu op (MIOpen variant on HIP builds, cuDNN
        # variant otherwise) must match relu(conv2d(inp, w) + alpha * z), and must
        # return its result in the memory format of its inputs.
        batches = (1, 2, 3)
        channel_counts = (1, 2, 4)
        image_sizes = ((1, 1), (8, 8))
        kernel_sizes = ((1, 1), (3, 3))
        layouts = (torch.channels_last, torch.contiguous_format)
        for batch, groups, image_size, kernel_size, memory_format in product(
                batches, channel_counts, image_sizes, kernel_sizes, layouts):
            # Skip combinations where the kernel is taller than the image.
            if image_size[0] < kernel_size[0]:
                continue
            # `groups` is used as the channel count; the convolution itself is
            # called with groups=1 (last positional argument).
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1)
            alpha = 2.0
            z = torch.randn_like(conv2d_out)

            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            z = z.to(memory_format=memory_format)
            fused_op = torch.miopen_convolution_add_relu if torch.version.hip \
                else torch.cudnn_convolution_add_relu
            cudnn_out = fused_op(inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1)

            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
            # float32 with TF32 active gets explicit, looser tolerances.
            if tf32_is_not_fp32() and dtype == torch.float:
                tolerances = {'atol': 2e-3, 'rtol': 0.006}
            else:
                tolerances = {}
            self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out, **tolerances)

    @onlyCUDA
    @skipCUDAIfRocm
    @skipCUDAIfCudnnVersionLessThan(7603)
    def test_convert_conv2d_weight_memory_format(self, device):
        # After nn.utils.convert_conv2d_weight_memory_format, a conv + batch-norm
        # model must emit outputs in the requested memory format. Checked for both
        # Conv2d and ConvTranspose2d, converting to channels_last and then back to
        # contiguous.
        x = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
        for conv_cls in (nn.Conv2d, nn.ConvTranspose2d):
            model = nn.Sequential(conv_cls(8, 4, 3), nn.BatchNorm2d(4))
            model = model.to(device).float()
            for memory_format in (torch.channels_last, torch.contiguous_format):
                model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format)
                out = model(x)
                self.assertTrue(out.is_contiguous(memory_format=memory_format))

    @onlyCUDA
    @skipCUDAIfRocm
    @skipCUDAIfCudnnVersionLessThan(7603)
    def test_convert_conv3d_weight_memory_format(self, device):
        # After nn.utils.convert_conv3d_weight_memory_format, a ConvTranspose3d +
        # BatchNorm3d model must emit outputs in the requested memory format,
        # first channels_last_3d and then contiguous again.
        x = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device)
        layers = (nn.ConvTranspose3d(8, 4, 3), nn.BatchNorm3d(4))
        model = nn.Sequential(*layers).to(device).float()
        for memory_format in (torch.channels_last_3d, torch.contiguous_format):
            model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format)
            result = model(x)
            self.assertTrue(result.is_contiguous(memory_format=memory_format))

    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
        # Test that _convolution_double_backward() outputs the correct grad shapes
        # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a
        # specific case that was uncovered during the convolution consolidation effort.
        # The test can be safely deleted if _convolution_double_backward() is removed.

        input = torch.randn(2, 3, 6, device=device)
        weight = torch.randn(3, 3, 3, device=device)
        bias = torch.randn(3, device=device)
        stride = (2,)
        padding = (1,)
        dilation = (1,)
        transposed = False
        output_padding = (0,)
        groups = 1
        output = torch.ops.aten.convolution(input, weight, bias, stride, padding, dilation, transposed,
                                            output_padding, groups)

        ggI = torch.randn(input.shape, device=device)
        ggW = torch.randn(weight.shape, device=device)
        ggB = torch.randn(bias.shape, device=device)
        gO = torch.randn(output.shape, device=device)
        output_mask = [True, True, True]
        grad_grad_output, grad_input, grad_weight = torch.ops.aten._convolution_double_backward(
            ggI, ggW, ggB, gO, weight, input, stride, padding, dilation, transposed,
            output_padding, groups, output_mask)

        # Make sure the correct shapes are computed.
        self.assertEqual(grad_grad_output.shape, gO.shape)
        self.assertEqual(grad_input.shape, input.shape)
        self.assertEqual(grad_weight.shape, weight.shape)

    @onlyCUDA
    @largeTensorTest('40GB')
    @largeTensorTest('24GB', 'cpu')
    def test_conv3d_64bit_indexing(self, device):
        # A 1x1x1 Conv3d on a 1 x 32 x 512 x 512 x 256 input (2**31 elements) must
        # give the same result on `device` as on the CPU.
        cpu_input = torch.rand(1, 32, 512, 512, 256)
        conv = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False)
        expected = conv(cpu_input)
        # The device copy of the input is passed inline so that it is not kept
        # referenced past the call.
        conv = conv.to(device=device)
        actual = conv(cpu_input.to(device=device))
        self.assertEqual(expected, actual)


# Hand both test classes to the instantiation helpers imported at the top of
# this file.
# NOTE(review): presumably the first call emits per-device variants of
# TestConvolutionNNDeviceType into this module's globals and the second expands
# the parametrized tests declared on TestConvolutionNN -- confirm against
# torch.testing._internal.
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals())
instantiate_parametrized_tests(TestConvolutionNN)

# Start the test runner only when this file is executed directly as a script.
if __name__ == '__main__':
    run_tests()
